# Python script for rigid registration
data = pd.read_csv('./centered_X_Y.csv', sep=',')

# Split the table into one two-column (X_centered, Y_centered) array per
# distinct Z value, keyed by that Z value in ascending order.
slices = {
    z: data.loc[data['Z'] == z, ['X_centered', 'Y_centered']].to_numpy()
    for z in sorted(data['Z'].unique())
}

# define rigid_transform
def rigid_transform(points, angle, translation):
    """Rotate 2D points about the origin, then shift them.

    Args:
        points: array of row vectors with two columns (x, y).
        angle: rotation angle in radians; positive values rotate
            counter-clockwise.
        translation: offset added to every rotated point (broadcast
            against the rows of ``points``).

    Returns:
        The transformed points, one row per input point.
    """
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
    # Points are stored as rows, so right-multiply by the transpose.
    rotated = points @ rotation.T
    return rotated + translation

# estimate the rigid transform (angle, tx, ty) that aligns one slice onto another
def compute_alignment(slice1, slice2):
    """Estimate the rigid transform that best maps ``slice2`` onto ``slice1``.

    A candidate ``(angle, tx, ty)`` is scored by transforming ``slice2`` and
    summing the squared distance from each transformed point to its nearest
    neighbour in ``slice1``. The score is minimised with Powell's method,
    starting from the identity transform.

    Args:
        slice1: reference 2D point set (one point per row).
        slice2: 2D point set to be moved onto ``slice1``.

    Returns:
        The optimiser's solution vector ``[angle, tx, ty]`` (angle in
        radians), as returned in ``result.x``.
    """
    # slice1 is fixed for the whole optimisation, so build its KD-tree once
    # instead of rebuilding it on every cost evaluation.
    reference_tree = KDTree(slice1)

    def cost_function(params):
        angle, tx, ty = params
        transformed_points = rigid_transform(slice2, angle, np.array([tx, ty]))
        distances, _ = reference_tree.query(transformed_points)
        return np.sum(distances**2)

    init_params = [0, 0, 0]
    result = minimize(cost_function, init_params, method='Powell')
    return result.x

# align: register every slice onto the previously aligned one, walking the
# slices in ascending Z order. The actual Z keys are used rather than
# assuming they are exactly 1..N, so 0-based, gapped or non-integer Z labels
# are handled instead of raising KeyError or silently skipping slices.
z_levels = sorted(slices)
if not z_levels:
    raise ValueError("no slices to align")

# The lowest slice is the fixed reference; it is kept untransformed.
aligned_slices = {z_levels[0]: slices[z_levels[0]]}
for z_prev, z in zip(z_levels, z_levels[1:]):
    slice1 = aligned_slices[z_prev]
    slice2 = slices[z]
    angle, tx, ty = compute_alignment(slice1, slice2)
    aligned_slices[z] = rigid_transform(slice2, angle, np.array([tx, ty]))

# merge: append each slice's Z key as a third column, then stack all slices
# into a single array with columns (x, y, z).
points_3d = np.vstack([
    np.hstack((xy, np.full((len(xy), 1), z, dtype=float)))
    for z, xy in aligned_slices.items()
])

# visualization: 3D scatter plot of the merged point cloud
fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={'projection': '3d'})
xs, ys, zs = points_3d.T  # columns are x, y, z
ax.scatter(xs, ys, zs, s=1)
ax.set(xlabel='X', ylabel='Y', zlabel='Z')
plt.show()

